{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5b3a95f92a7642f48f2ec6c5d60cdcf6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(Dropdown(description='Model:', index=2, options=('MemSeg-leather', 'MemSeg-carpet', 'MemSeg-cap…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "04ac53e33a5f4c838981b734de75e76d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from data import create_dataset, create_dataloader\n",
    "\n",
    "import os\n",
    "import torch\n",
    "import yaml\n",
    "from timm import create_model\n",
    "from models import MemSeg\n",
    "\n",
    "import ipywidgets as widgets\n",
    "from IPython.display import display, clear_output\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Load the experiment configuration. The context manager closes the file handle\n",
    "# deterministically instead of leaving it open until garbage collection.\n",
    "# NOTE(review): yaml.FullLoader is kept for compatibility with the existing config;\n",
    "# only point this at trusted files (yaml.safe_load is the stricter alternative).\n",
    "with open('./configs/capsule.yaml', 'r') as cfg_file:\n",
    "    cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)\n",
    "\n",
    "# ====================================\n",
    "# Select Model\n",
    "# ====================================\n",
    "\n",
    "def load_model(model_name):\n",
    "    \"\"\"Build the evaluation dataset and restore the trained MemSeg model for one checkpoint.\n",
    "\n",
    "    The results are published through the module-level globals ``testset`` and\n",
    "    ``model`` so that ``result_plot`` can read them.\n",
    "\n",
    "    Args:\n",
    "        model_name: Name of a sub-directory of ``./saved_model``. The text between\n",
    "            the first and second ``-`` is used as the dataset target\n",
    "            (e.g. ``'MemSeg-capsule'`` -> ``'capsule'``).\n",
    "\n",
    "    Raises:\n",
    "        IndexError: If ``model_name`` contains no ``-``.\n",
    "    \"\"\"\n",
    "    global model\n",
    "    global testset\n",
    "\n",
    "    dataset_cfg = cfg['DATASET']\n",
    "    testset = create_dataset(\n",
    "        datadir                = dataset_cfg['datadir'],\n",
    "        target                 = model_name.split('-')[1],\n",
    "        train                  = False,\n",
    "        resize                 = dataset_cfg['resize'],\n",
    "        texture_source_dir     = dataset_cfg['texture_source_dir'],\n",
    "        structure_grid_size    = dataset_cfg['structure_grid_size'],\n",
    "        transparency_range     = dataset_cfg['transparency_range'],\n",
    "        perlin_scale           = dataset_cfg['perlin_scale'],\n",
    "        min_perlin_scale       = dataset_cfg['min_perlin_scale'],\n",
    "        perlin_noise_threshold = dataset_cfg['perlin_noise_threshold']\n",
    "    )\n",
    "\n",
    "    checkpoint_dir = f'./saved_model/{model_name}'\n",
    "\n",
    "    # map_location='cpu' lets checkpoints that were saved from a GPU be restored on a\n",
    "    # machine without CUDA; the code below keeps everything on the CPU anyway.\n",
    "    memory_bank = torch.load(f'{checkpoint_dir}/memory_bank.pt', map_location='cpu')\n",
    "    memory_bank.device = 'cpu'\n",
    "    for k in memory_bank.memory_information.keys():\n",
    "        memory_bank.memory_information[k] = memory_bank.memory_information[k].cpu()\n",
    "\n",
    "    feature_extractor = create_model(\n",
    "        cfg['MODEL']['feature_extractor_name'],\n",
    "        pretrained    = True,\n",
    "        features_only = True\n",
    "    )\n",
    "    model = MemSeg(\n",
    "        memory_bank       = memory_bank,\n",
    "        feature_extractor = feature_extractor\n",
    "    )\n",
    "\n",
    "    model.load_state_dict(torch.load(f'{checkpoint_dir}/best_model.pt', map_location='cpu'))\n",
    "\n",
    "    # The model is only used for prediction here: switch to evaluation mode so that\n",
    "    # layers such as batch norm and dropout use their inference behaviour.\n",
    "    model.eval()\n",
    "\n",
    "# ====================================\n",
    "# Visualization\n",
    "# ====================================\n",
    "\n",
    "\n",
    "def result_plot(idx):\n",
    "    \"\"\"Plot the input, ground-truth mask and predicted mask for one dataset sample.\n",
    "\n",
    "    Reads the module-level ``testset`` and ``model`` that ``load_model`` sets.\n",
    "\n",
    "    Args:\n",
    "        idx: Index of the sample in ``testset``.\n",
    "    \"\"\"\n",
    "    input_i, mask_i, target_i = testset[idx]\n",
    "\n",
    "    # Prediction only: no_grad avoids building an autograd graph for the forward pass.\n",
    "    with torch.no_grad():\n",
    "        output_i = model(input_i.unsqueeze(0))\n",
    "        output_i = torch.nn.functional.softmax(output_i, dim=1)\n",
    "\n",
    "    def minmax_scaling(img):\n",
    "        \"\"\"Linearly rescale ``img`` to 0-255 and cast to uint8.\"\"\"\n",
    "        lo, hi = img.min(), img.max()\n",
    "        if hi == lo:\n",
    "            # Constant image: the original formula would divide by zero (NaN).\n",
    "            return torch.zeros_like(img, dtype=torch.uint8)\n",
    "        return (((img - lo) / (hi - lo)) * 255).to(torch.uint8)\n",
    "\n",
    "    # Computed once and reused by the first and last panels.\n",
    "    input_image = minmax_scaling(input_i.permute(1,2,0))\n",
    "    # Channel 1 of the softmax output is what gets displayed as the prediction\n",
    "    # (presumably the anomaly class - TODO confirm against the model definition).\n",
    "    pred_mask = output_i[0][1]\n",
    "\n",
    "    fig, ax = plt.subplots(1,4, figsize=(15,10))\n",
    "\n",
    "    ax[0].imshow(input_image)\n",
    "    ax[0].set_title('Input: {}'.format('Normal' if target_i == 0 else 'Abnormal'))\n",
    "    ax[1].imshow(mask_i, cmap='gray')\n",
    "    ax[1].set_title('Ground Truth')\n",
    "    ax[2].imshow(pred_mask, cmap='gray')\n",
    "    ax[2].set_title('Predicted Mask')\n",
    "    # Last panel: prediction drawn half-transparent on top of the input image.\n",
    "    ax[3].imshow(input_image, alpha=1)\n",
    "    ax[3].imshow(pred_mask, cmap='gray', alpha=0.5)\n",
    "    ax[3].set_title('Input X Predicted Mask')\n",
    "\n",
    "    plt.show()\n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# ====================================\n",
    "# widgets\n",
    "# ====================================\n",
    "\n",
    "\n",
    "# Offer only visible sub-directories of ./saved_model: load_model reads files from\n",
    "# inside the selected entry, so stray files or hidden folders are not valid choices.\n",
    "# Sorting keeps the menu order stable across runs and platforms.\n",
    "available_models = sorted(\n",
    "    name for name in os.listdir('./saved_model')\n",
    "    if os.path.isdir(os.path.join('./saved_model', name)) and not name.startswith('.')\n",
    ")\n",
    "\n",
    "# Pre-select the capsule model when it exists; otherwise fall back to the first\n",
    "# available entry (or nothing) instead of failing on an invalid initial value.\n",
    "if 'MemSeg-capsule' in available_models:\n",
    "    default_model = 'MemSeg-capsule'\n",
    "elif available_models:\n",
    "    default_model = available_models[0]\n",
    "else:\n",
    "    default_model = None\n",
    "\n",
    "model_list = widgets.Dropdown(\n",
    "    options=available_models,\n",
    "    value=default_model,\n",
    "    description='Model:',\n",
    "    disabled=False,\n",
    ")\n",
    "button = widgets.Button(description=\"Model Change\")\n",
    "# Everything the button callback prints or displays is rendered in this widget.\n",
    "output = widgets.Output()\n",
    "\n",
    "\n",
    "@output.capture()\n",
    "def on_button_clicked(b):\n",
    "    \"\"\"Button callback: load the selected model and show an image browser for it.\n",
    "\n",
    "    Runs under ``output.capture()``, so everything displayed here lands in the\n",
    "    ``output`` widget.\n",
    "\n",
    "    Args:\n",
    "        b: Argument passed by the button's click handler; not used.\n",
    "    \"\"\"\n",
    "    # Drop whatever a previous click rendered once new content arrives.\n",
    "    clear_output(wait=True)\n",
    "    load_model(model_name=model_list.value)\n",
    "\n",
    "    # Visualization: one dropdown entry per dataset file, labelled with its path and\n",
    "    # carrying the sample index as its value.\n",
    "    choices = []\n",
    "    for position, path in enumerate(testset.file_list):\n",
    "        choices.append((path, position))\n",
    "\n",
    "    image_picker = widgets.Dropdown(\n",
    "        options=choices,\n",
    "        value=0,\n",
    "        description='image:',\n",
    "    )\n",
    "\n",
    "    # Re-draw the result figure whenever another image is picked.\n",
    "    widgets.interact(result_plot, idx=image_picker)\n",
    "\n",
    "# Hook the callback up to the button.\n",
    "button.on_click(on_button_clicked)\n",
    "\n",
    "# Show the model selector next to the button, with the output area below them.\n",
    "display(\n",
    "    widgets.HBox([model_list, button]),\n",
    "    output,\n",
    ")\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
